
import torch
import torch.nn as nn
import torch.nn.functional as F

class ICPRegistration(nn.Module):
    """Point-to-point ICP that rigidly aligns a source point cloud to a reference one.

    Points are stored as rows; a rigid transform ``(R, t)`` maps a point ``p``
    to ``p @ R.T + t``.

    Args:
        num_correspondences: stored on the module; not read by any method of
            this class.
        max_iterations: upper bound on ICP iterations. With 0 no alignment is
            performed and only nearest-neighbour correspondences are computed.
        tolerance: iteration stops once ``||R - I|| + ||t||`` of the latest
            incremental transform drops below this value.
    """

    def __init__(self, num_correspondences, max_iterations=50, tolerance=1e-6):
        super(ICPRegistration, self).__init__()
        self.num_correspondences = num_correspondences
        self.max_iterations = max_iterations
        self.tolerance = tolerance

    def forward(self, ref_points, src_points):
        """Align ``src_points`` to ``ref_points`` and return the final correspondences.

        Args:
            ref_points: fixed reference points, shape [M, D].
            src_points: moving source points, shape [N, D]. The caller's tensor
                is not modified; a transformed copy is iterated on.

        Returns:
            ref_corr_indices: ``arange(M)`` -- one correspondence per reference point.
            src_corr_indices: [M] index of the nearest aligned source point for
                each reference point.
            node_corr_scores: [M] squared distance of each pair (lower is better).
        """
        identity = torch.eye(ref_points.shape[1], dtype=ref_points.dtype, device=ref_points.device)

        for _ in range(self.max_iterations):
            # For every source point, the index of its nearest reference point;
            # row j of both arguments below is therefore a genuine pair.
            _, nearest_ref_for_src, _ = self.find_closest_points(ref_points, src_points)
            rotation, translation = self.compute_transform(ref_points[nearest_ref_for_src], src_points)

            src_points = torch.matmul(src_points, rotation.T) + translation

            # Translation alone is not a sufficient stop criterion: when both
            # centroids sit at the origin a large rotation has ~zero translation.
            step = torch.norm(rotation - identity) + torch.norm(translation)
            if step < self.tolerance:
                break

        # Take correspondences after the last update so they reflect the final
        # alignment (this also covers max_iterations == 0).
        nearest_src_for_ref, _, node_corr_scores = self.find_closest_points(ref_points, src_points)
        ref_corr_indices = torch.arange(ref_points.shape[0], device=ref_points.device)
        return ref_corr_indices, nearest_src_for_ref, node_corr_scores

    def find_closest_points(self, ref_points, src_points):
        """Brute-force nearest neighbours in both directions.

        Args:
            ref_points: [M, D] reference points.
            src_points: [N, D] source points.

        Returns:
            nearest_src_for_ref: [M] index into ``src_points`` of each reference
                point's nearest source point.
            nearest_ref_for_src: [N] index into ``ref_points`` of each source
                point's nearest reference point.
            nearest_sq_dist: [M] squared distance from each reference point to
                its nearest source point.
        """
        dist_matrix = self.pairwise_distance(ref_points, src_points)  # [M, N]

        nearest_sq_dist, nearest_src_for_ref = torch.min(dist_matrix, dim=1)
        _, nearest_ref_for_src = torch.min(dist_matrix, dim=0)

        return nearest_src_for_ref, nearest_ref_for_src, nearest_sq_dist

    def compute_transform(self, ref_points, src_points):
        """Least-squares rigid transform (Kabsch) mapping ``src_points`` onto ``ref_points``.

        Row ``i`` of ``ref_points`` is paired with row ``i`` of ``src_points``;
        both have shape [K, D].

        Returns:
            rotation: [D, D] proper rotation (det = +1).
            translation: [D] vector such that
                ``src_points @ rotation.T + translation`` approximates ``ref_points``.
        """
        ref_centroid = ref_points.mean(dim=0)
        src_centroid = src_points.mean(dim=0)

        ref_centered = ref_points - ref_centroid
        src_centered = src_points - src_centroid

        # Cross-covariance: [D, K] @ [K, D] -> [D, D]
        H = torch.matmul(src_centered.T, ref_centered)
        U, _, Vh = torch.linalg.svd(H)
        V = Vh.T

        # V @ U.T can be a reflection (det = -1); flipping the axis of the
        # smallest singular value (last, since they are sorted descending)
        # yields the closest proper rotation instead.
        correction = torch.ones(H.shape[0], dtype=H.dtype, device=H.device)
        correction[-1] = torch.sign(torch.det(torch.matmul(V, U.T)))
        rotation = V @ torch.diag(correction) @ U.T

        # Row-vector convention matches forward(): p -> p @ R.T + t.
        translation = ref_centroid - torch.matmul(src_centroid, rotation.T)

        return rotation, translation

    def pairwise_distance(self, x, y):
        """Squared Euclidean distances between all rows of ``x`` [M, D] and ``y`` [N, D].

        Returns an [M, N] tensor. Values are clamped at 0 because the expanded
        form ``|x|^2 + |y|^2 - 2 x.y`` can go slightly negative from rounding.
        """
        x2 = torch.sum(x ** 2, dim=1, keepdim=True)  # [M, 1]
        y2 = torch.sum(y ** 2, dim=1, keepdim=True)  # [N, 1]
        dist = x2 + y2.T - 2 * torch.matmul(x, y.T)  # [M, N]
        return dist.clamp(min=0.0)


class GaussianCoarseRegistration(nn.Module):
    """Coarse registration of two Gaussian parameter sets via ICP on their first three columns.

    NOTE(review): the first three columns are presumably the Gaussian means
    (the original code names them ``*_mu``) -- confirm against whatever
    produces ``*_gs_params``.
    """

    def __init__(self, num_correspondences):
        super().__init__()
        # Kept on the module and forwarded to the ICP solver below.
        self.num_correspondences = num_correspondences
        self.icp_registration = ICPRegistration(num_correspondences)

    def forward(self, ref_gs_params, src_gs_params):
        """Run ICP between the leading three columns of both parameter sets.

        Both inputs are first cut down to the row count of the shorter one,
        so trailing rows of the longer input are ignored.

        Returns:
            The ``(ref_corr_indices, src_corr_indices, corr_scores)`` triple
            produced by ``self.icp_registration``.
        """
        shared_rows = min(ref_gs_params.shape[0], src_gs_params.shape[0])
        ref_centres = ref_gs_params[:shared_rows, :3]
        src_centres = src_gs_params[:shared_rows, :3]

        ref_idx, src_idx, scores = self.icp_registration(ref_centres, src_centres)
        return ref_idx, src_idx, scores
    

class CoarseMatchFusion(nn.Module):
    """Fuse two lists of (ref index, src index, score) correspondences into one.

    Each list is first reduced to its ``length`` highest-scoring entries. The
    survivors are pooled and ranked by ``weight_i * score``, and the best
    ``max_correspondences`` pooled entries are returned.

    NOTE(review): ranking uses ``largest=True``, i.e. it assumes higher score
    means a better match. ``ICPRegistration`` in this file returns squared
    distances (lower is better) as its scores -- confirm which matcher feeds
    which argument.

    Args:
        length: per-list cap applied before pooling.
        weight_1: weight applied to list 1's scores for ranking.
        weight_2: weight applied to list 2's scores for ranking.
        max_correspondences: cap on the number of fused correspondences.
    """

    def __init__(self, length, weight_1=0.01, weight_2=0.99, max_correspondences=256):
        super(CoarseMatchFusion, self).__init__()
        self.weight_1 = weight_1
        self.weight_2 = weight_2
        self.length = length
        self.max_correspondences = max_correspondences

    @staticmethod
    def _select_topk(ref_indices, src_indices, scores, k):
        """Keep the (at most) ``k`` highest-scoring aligned entries of one list."""
        k = max(0, min(len(scores), k))
        keep = torch.topk(scores, k=k, largest=True).indices
        return ref_indices[keep], src_indices[keep], scores[keep]

    def forward(self,
                ref_node_corr_indices_1, src_node_corr_indices_1,
                ref_node_corr_indices_2, src_node_corr_indices_2,
                node_corr_scores_1, node_corr_scores_2):
        """Return the fused ``(ref_indices, src_indices, scores)``.

        The three inputs of each list must be aligned element-wise: entry ``k``
        describes one correspondence. The two lists need not have the same
        length. Returned scores are the original (unweighted) scores of the
        selected correspondences, ordered by weighted score, highest first.
        """
        ref_1, src_1, scores_1 = self._select_topk(
            ref_node_corr_indices_1, src_node_corr_indices_1, node_corr_scores_1, self.length)
        ref_2, src_2, scores_2 = self._select_topk(
            ref_node_corr_indices_2, src_node_corr_indices_2, node_corr_scores_2, self.length)

        # Pool both candidate lists. Each correspondence keeps its own
        # (weighted) score, so no positional pairing or padding is needed.
        ranking = torch.cat([self.weight_1 * scores_1, self.weight_2 * scores_2])
        pooled_ref = torch.cat([ref_1, ref_2])
        pooled_src = torch.cat([src_1, src_2])
        pooled_scores = torch.cat([scores_1, scores_2])

        k = max(0, min(len(ranking), self.max_correspondences))
        order = torch.topk(ranking, k=k, largest=True).indices

        return pooled_ref[order], pooled_src[order], pooled_scores[order]

